{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "import torch\n",
    "import argparse\n",
    "import matplotlib.pyplot as plt\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "from models.soft_shift_net.innerSoftShiftTriple import InnerSoftShiftTriple\n",
    "#from models.accelerated_shift_net.accelerated_InnerShiftTriple import AcceleratedInnerShiftTriple\n",
    "from options.train_options import TrainOptions \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CREATE DEFAULT OPTIONS TO INITIALIZE THE SHIFTMODEL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset root passed to TrainOptions. Set the DATAROOT environment variable to override\n",
    "# the default below instead of editing a machine-specific absolute path.\n",
    "dataroot = os.environ.get('DATAROOT', '/mnt/hdd2/AIM/DAGM/Class4_def/')\n",
    "# Build the argv-style list directly: splitting a formatted string on ' ' would break\n",
    "# a dataroot containing a space into several arguments.\n",
    "options = ['--dataroot', dataroot]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_parser(options=None):\n",
    "    \"\"\"Create a TrainOptions object, parse `options` into it and return it.\n",
    "\n",
    "    The parsed values are read afterwards from the returned object's `.opt`.\n",
    "    \"\"\"\n",
    "    train_options = TrainOptions()\n",
    "    train_options.parse(options=options)\n",
    "    return train_options"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------- Options ---------------\n",
      "           add_mask2input: False                         \n",
      "                batchSize: 1                             \n",
      "                    beta1: 0.5                           \n",
      "               bottleneck: 512                           \n",
      "          checkpoints_dir: ./log                         \n",
      "                constrain: MSE                           \n",
      "           continue_train: False                         \n",
      "                 dataroot: /mnt/hdd2/AIM/DAGM/Class4_def/\t[default: ./datasets/Paris/train]\n",
      "             dataset_mode: aligned                       \n",
      "             display_freq: 10                            \n",
      "               display_id: 1                             \n",
      "            display_ncols: 4                             \n",
      "             display_port: 8097                          \n",
      "           display_server: http://localhost              \n",
      "display_single_pane_ncols: 0                             \n",
      "          display_winsize: 256                           \n",
      "              epoch_count: 1                             \n",
      "                 fineSize: 256                           \n",
      "               fixed_mask: 1                             \n",
      "                 gan_type: vanilla                       \n",
      "               gan_weight: 0.2                           \n",
      "                gp_lambda: 10.0                          \n",
      "                  gpu_ids: 0                             \n",
      "                init_gain: 0.02                          \n",
      "                init_type: normal                        \n",
      "                 input_nc: 3                             \n",
      "                  isTrain: True                          \t[default: None]\n",
      "                 lambda_A: 100                           \n",
      "                 loadSize: 350                           \n",
      "                       lr: 0.0002                        \n",
      "           lr_decay_iters: 50                            \n",
      "                lr_policy: lambda                        \n",
      "            mask_sub_type: island                        \n",
      "               mask_thred: 1                             \n",
      "                mask_type: random                        \n",
      "         max_dataset_size: inf                           \n",
      "                    model: accelerated_shiftnet          \n",
      "                 nThreads: 2                             \n",
      "               n_layers_D: 3                             \n",
      "                     name:                               \n",
      "                  ncritic: 5                             \n",
      "                      ndf: 64                            \n",
      "                      ngf: 64                            \n",
      "                    niter: 10000000                      \n",
      "              niter_decay: 0                             \n",
      "                  no_flip: False                         \n",
      "                  no_html: False                         \n",
      "                     norm: instance                      \n",
      "             only_lastest: True                          \n",
      "                output_nc: 3                             \n",
      "                  overlap: 4                             \n",
      "                    phase: train                         \n",
      "               print_freq: 50                            \n",
      "           resize_or_crop: resize_and_crop               \n",
      "          save_epoch_freq: 2                             \n",
      "         save_latest_freq: 5000                          \n",
      "           serial_batches: False                         \n",
      "                 shift_sz: 1                             \n",
      "                     skip: 0                             \n",
      "                 strength: 1                             \n",
      "                   stride: 1                             \n",
      "                   suffix:                               \n",
      "                threshold: 0.3125                        \n",
      "            triple_weight: 1                             \n",
      "         update_html_freq: 1000                          \n",
      "              use_dropout: False                         \n",
      "                  verbose: False                         \n",
      "              which_epoch: latest                        \n",
      "         which_model_netD: densenet                      \n",
      "         which_model_netG: acc_unet_shift_triple         \n",
      "----------------- End -------------------\n"
     ]
    }
   ],
   "source": [
    "parser = get_parser(options=options)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CREATE INNER_SHIFT_TRIPLE LAYER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "#from models.InnerShiftTriple import InnerShiftTriple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt = parser.opt\n",
    "#opt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "inner_shift_triple = InnerSoftShiftTriple(opt.threshold, opt.fixed_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "InnerSoftShiftTriple(threshold: 0.3125 ,triple_weight 1)"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inner_shift_triple.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# EVALUATE FORWARD SPEED"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### THE SIZE OF THE INPUT TENSOR IS (BATCH_SIZE, 256 * 2 (former | latter), 32, 32). LET'S CREATE A RANDOM TENSOR AND EVALUATE ITS FORWARD PASS FIRST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### NOW WE NEED TO SET UP THE MASK"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f1979918588>"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQYAAAD8CAYAAACVSwr3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADJBJREFUeJzt3H+s3XV9x/Hna7TUDF2E4ZpSmoGm+wOXrJIbJJEYFzKBZknxHwJ/SGdI6h+YaOL+qPqH/GPilqmZyUZSI7EuTiRTQ/9gU2xMjH+oFFKBwtCKJbQrdA6CZCYV8L0/7rd47Pve3tt7z7nn3Pl8JDfnez/ne+5595vmme/5mapCkkb9wbQHkDR7DIOkxjBIagyDpMYwSGoMg6RmYmFIcmOSp5IcTbJ3UvcjafwyifcxJLkA+AnwV8Bx4CHgtqp6Yux3JmnsJnXGcA1wtKqerqpfA/cCuyZ0X5LGbMOE/u5W4NmR348D71xs5wuzqd7ARRMaRRLAy7z4i6p6y3L2nVQYlpRkD7AH4A38Ie/M9dMaRfq98J36t2eWu++kHkqcALaN/H75sPa6qtpXVXNVNbeRTRMaQ9JKTCoMDwHbk1yZ5ELgVuDAhO5L0phN5KFEVb2a5EPAt4ALgHuq6sgk7kvS+E3sOYaqegB4YFJ/X9Lk+M5HSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUbFjNjZMcA14GXgNeraq5JJcAXwOuAI4Bt1TVi6sbU9JaGscZw19W1Y6qmht+3wscrKrtwMHhd0nryCQeSuwC9g/b+4GbJ3AfkiZotWEo4NtJHk6yZ1jbXFUnh+3ngM0L3TDJniSHkhx6hdOrHEPSOK3qOQbguqo6keRPgAeT/OfolVVVSWqhG1bVPmAfwB/lkgX3kTQdqzpjqKoTw+Up4JvANcDzSbYADJenVjukpLW14jAkuSjJm85sA+8FHgcOALuH3XYD9692SElrazUPJTYD30xy5u/8a1X9R5KHgPuS3AE8A9yy+jElraUVh6Gqngb+YoH1/wGuX81QkqbLdz5KagyDpMYwSGoMg6TGMEhqDIOkxjBIagyDpMYwSGoMg6TGMEhqDIOkxjBIagyDpMYwSGoMg6TGMEhqDIOkxjBIagyDpMYwSGoMg6TGMEhqDIOkxjBIagyDpMYwSGoMg6TGMEhqDIOkxjBIagyDpMYwSGo2LLVDknuAvwZOVdWfD2uXAF8DrgCOAbdU1YtJAvwjsBP4FfA3VfXIZEbX2b71X4enPcJE3HDZjmmP8HtnOWcMXwJuPGttL3CwqrYDB4ffAW4Ctg8/e4C7xzOmpLW0ZBiq6nvAC2ct7wL2D9v7gZtH1r9c834AvDnJlnENK2ltrPQ5hs1VdXLYfg7YPGxvBZ4d2e/4sCZpHVn1k49VVUCd7+2S7ElyKMmhVzi92jEkjdFKw/D8mYcIw+WpYf0EsG1kv8uHtaaq9lXVXFXNbWTTCseQNAkrDcMBYPewvRu4f2T99sy7Fnhp5CGHpHViOS9XfhV4D3BpkuPAJ4FPA/cluQN4Brhl2P0B5l+qPMr8y5UfmMDMkiZsyTBU1W2LXHX9AvsWcOdqh5I0Xb7zUVJjGCQ1hkFSYxgkNYZBUmMYJDWGQVJjGCQ1hkFSYxgkNYZBUmMYJDWGQVJjGCQ1hkFSYxgkNYZBUmMYJDWGQVJjGCQ1hkFSYxgkNYZBUmMYJDWGQVJjGCQ1hkFSYxgkNYZBUmMYJDWGQVJjGCQ1hkFSYxgkNUuGIck9SU4leXxk7a4kJ5IcHn52jlz3sSRHkzyV5IZJDS5pcpZzxvAl4MYF1j9XVTuGnwcAklwF3Aq8
fbjNPye5YFzDSlobS4ahqr4HvLDMv7cLuLeqTlfVz4GjwDWrmE/SFKzmOYYPJXl0eKhx8bC2FXh2ZJ/jw1qTZE+SQ0kOvcLpVYwhadxWGoa7gbcBO4CTwGfO9w9U1b6qmququY1sWuEYkiZhRWGoquer6rWq+g3wBX77cOEEsG1k18uHNUnryIrCkGTLyK/vA868YnEAuDXJpiRXAtuBH61uRElrbcNSOyT5KvAe4NIkx4FPAu9JsgMo4BjwQYCqOpLkPuAJ4FXgzqp6bTKjS5qUJcNQVbctsPzFc+z/KeBTqxlK0nT5zkdJjWGQ1Cz5UELrxw2X7Zj2CPp/wjMGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1CwZhiTbknw3yRNJjiT58LB+SZIHk/x0uLx4WE+Szyc5muTRJFdP+h8habyWc8bwKvDRqroKuBa4M8lVwF7gYFVtBw4OvwPcBGwffvYAd499akkTtWQYqupkVT0ybL8MPAlsBXYB+4fd9gM3D9u7gC/XvB8Ab06yZeyTS5qY83qOIckVwDuAHwKbq+rkcNVzwOZheyvw7MjNjg9rktaJZYchyRuBrwMfqapfjl5XVQXU+dxxkj1JDiU59Aqnz+emkiZsWWFIspH5KHylqr4xLD9/5iHCcHlqWD8BbBu5+eXD2u+oqn1VNVdVcxvZtNL5JU3Acl6VCPBF4Mmq+uzIVQeA3cP2buD+kfXbh1cnrgVeGnnIIWkd2LCMfd4FvB94LMnhYe3jwKeB+5LcATwD3DJc9wCwEzgK/Ar4wFgnljRxS4ahqr4PZJGrr19g/wLuXOVckqbIdz5KagyDpMYwSGoMg6TGMEhqDIOkxjBIagyDpMYwSGoMg6TGMEhqDIOkxjBIagyDpMYwSGoMg6TGMEhqDIOkxjBIagyDpMYwSGoMg6TGMEhqDIOkxjBIagyDpMYwSGoMg6TGMEhqDIOkxjBIagyDpMYwSGoMg6RmyTAk2Zbku0meSHIkyYeH9buSnEhyePjZOXKbjyU5muSpJDdM8h8gafw2LGOfV4GPVtUjSd4EPJzkweG6z1XVP4zunOQq4Fbg7cBlwHeS/FlVvTbOwSVNzpJnDFV1sqoeGbZfBp4Etp7jJruAe6vqdFX9HDgKXDOOYSWtjfN6jiHJFcA7gB8OSx9K8miSe5JcPKxtBZ4dudlxFghJkj1JDiU59Aqnz3twSZOz7DAkeSPwdeAjVfVL4G7gbcAO4CTwmfO546raV1VzVTW3kU3nc1NJE7asMCTZyHwUvlJV3wCoquer6rWq+g3wBX77cOEEsG3k5pcPa5LWieW8KhHgi8CTVfXZkfUtI7u9D3h82D4A3JpkU5Irge3Aj8Y3sqRJW86rEu8C3g88luTwsPZx4LYkO4ACjgEfBKiqI0nuA55g/hWNO31FQlpfUlXTnoEk/w38L/CLac+yDJeyPuaE9TOrc47fQrP+aVW9ZTk3nokwACQ5VFVz055jKetlTlg/szrn+K12Vt8SLakxDJKaWQrDvmkPsEzrZU5YP7M65/itataZeY5B0uyYpTMGSTNi6mFIcuPw8eyjSfZOe56zJTmW5LHho+WHhrVLkjyY5KfD5cVL/Z0JzHVPklNJHh9ZW3CuzPv8cIwfTXL1DMw6cx/bP8dXDMzUcV2Tr0Koqqn9ABcAPwPeClwI/Bi4apozLTDjMeDSs9b+Htg7bO8F/m4Kc70buBp4fKm5gJ3AvwMBrgV+OAOz3gX87QL7XjX8P9gEXDn8/7hgjebcAlw9bL8J+Mkwz0wd13PMObZjOu0zhmuAo1X1dFX9GriX+Y9tz7pdwP5hez9w81oPUFXfA144a3mxuXYBX655PwDefNZb2idqkVkXM7WP7dfiXzEwU8f1HHMu5ryP
6bTDsKyPaE9ZAd9O8nCSPcPa5qo6OWw/B2yezmjNYnPN6nFe8cf2J+2srxiY2eM6zq9CGDXtMKwH11XV1cBNwJ1J3j16Zc2fq83cSzuzOteIVX1sf5IW+IqB183ScR33VyGMmnYYZv4j2lV1Yrg8BXyT+VOw58+cMg6Xp6Y34e9YbK6ZO841ox/bX+grBpjB4zrpr0KYdhgeArYnuTLJhcx/V+SBKc/0uiQXDd9zSZKLgPcy//HyA8DuYbfdwP3TmbBZbK4DwO3Ds+jXAi+NnBpPxSx+bH+xrxhgxo7rYnOO9ZiuxbOoSzzDupP5Z1V/Bnxi2vOcNdtbmX8298fAkTPzAX8MHAR+CnwHuGQKs32V+dPFV5h/zHjHYnMx/6z5Pw3H+DFgbgZm/ZdhlkeH/7hbRvb/xDDrU8BNazjndcw/THgUODz87Jy143qOOcd2TH3no6Rm2g8lJM0gwyCpMQySGsMgqTEMkhrDIKkxDJIawyCp+T/ngnElFjH4JwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Square of side 2 * hm_size, centred in a fineSize x fineSize image, is set to 1.\n",
    "c, h, w = 1, opt.fineSize, opt.fineSize\n",
    "hh, wh = h // 2, w // 2\n",
    "hm_size = 32\n",
    "# Build the mask directly as uint8 so it becomes a ByteTensor without a float64 detour.\n",
    "mask = np.zeros((1, c, h, w), dtype=np.uint8)\n",
    "mask[..., hh - hm_size:hh + hm_size, wh - hm_size:wh + hm_size] = 1\n",
    "mask_global = torch.from_numpy(mask).cuda()\n",
    "plt.imshow(np.squeeze(mask));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 0, 0,  ..., 0, 0, 0],\n",
       "        [0, 0, 0,  ..., 0, 0, 0],\n",
       "        [0, 0, 0,  ..., 0, 0, 0],\n",
       "        ...,\n",
       "        [0, 0, 0,  ..., 0, 0, 0],\n",
       "        [0, 0, 0,  ..., 0, 0, 0],\n",
       "        [0, 0, 0,  ..., 0, 0, 0]], device='cuda:0', dtype=torch.uint8)"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inner_shift_triple.set_mask(mask_global=mask_global, threshold=opt.threshold, layer_to_last=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random stand-in for the layer input, shaped as described above: (1, 256 * 2, 32, 32).\n",
    "# Seeded so the input (and therefore the layer output) is reproducible across runs.\n",
    "rng = np.random.RandomState(0)\n",
    "x_np = rng.normal(0, 1, (1, 512, 32, 32))\n",
    "x_tr = torch.FloatTensor(x_np)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.05 ms ± 846 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "# Copy the input to the GPU once so the host-to-device transfer is not timed, and\n",
    "# synchronize inside the timed statement: CUDA kernels launch asynchronously, so\n",
    "# without it %timeit may only measure launch overhead.\n",
    "x_gpu = x_tr.cuda()\n",
    "%timeit inner_shift_triple(x_gpu); torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Names assigned inside %timeit are local to the timing function and are discarded,\n",
    "# so run the forward pass once more to bind `output` for the cells below.\n",
    "output = inner_shift_triple(x_tr.cuda())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from util import util\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_random_mask(opt):\n",
    "    \"\"\"Build a random mask with util.create_gMask and return it.\n",
    "\n",
    "    A centred square (inset by opt.overlap) is written into an all-zero\n",
    "    (1, 1, fineSize, fineSize) uint8 tensor; that tensor and a random binary\n",
    "    pattern are passed to util.create_gMask, whose result is returned.\n",
    "    \"\"\"\n",
    "    fine_size = opt.fineSize\n",
    "    # Artificial initial mask_global (must not be broken, so a centre hole is fine).\n",
    "    mask_global = torch.zeros(1, 1, fine_size, fine_size, dtype=torch.uint8)\n",
    "    lo = fine_size // 4 + opt.overlap\n",
    "    hi = fine_size // 2 + fine_size // 4 - opt.overlap\n",
    "    mask_global[:, :, lo:hi, lo:hi] = 1\n",
    "\n",
    "    res = 0.06  # the lower it is, the more continuous the output will be. 0.01 is too small and 0.1 is too large\n",
    "    density = 0.25\n",
    "    MAX_SIZE = 300\n",
    "    maxPartition = 30\n",
    "    # Low-resolution noise upsampled bilinearly gives smooth blobs. F.interpolate replaces\n",
    "    # the deprecated F.upsample; align_corners=False is the value F.upsample used by default.\n",
    "    low_pattern = torch.rand(1, 1, int(res * MAX_SIZE), int(res * MAX_SIZE)).mul(255)\n",
    "    pattern = F.interpolate(low_pattern, size=(MAX_SIZE, MAX_SIZE), mode='bilinear', align_corners=False)\n",
    "    pattern.div_(255)\n",
    "    pattern = torch.squeeze(torch.lt(pattern, density).byte())  # roughly 25% 1s and 75% 0s\n",
    "    gMask_opts = {\n",
    "        'pattern': pattern,\n",
    "        'MAX_SIZE': MAX_SIZE,\n",
    "        'fineSize': fine_size,\n",
    "        'maxPartition': maxPartition,\n",
    "        'mask_global': mask_global,\n",
    "    }\n",
    "    return util.create_gMask(gMask_opts)  # create an initial random mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%time mask_global = create_random_mask(opt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mask_global.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(np.squeeze(mask_global))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mask_global"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inner_shift_triple.set_mask(mask_global=mask_global, threshold=opt.threshold, layer_to_last=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input copied to the GPU outside the timed statement; synchronize so the asynchronous\n",
    "# CUDA work is included in every measured loop. Call the module, not .forward, so hooks run.\n",
    "x_gpu = x_tr.cuda()\n",
    "%timeit inner_shift_triple(x_gpu); torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# THE ENTIRE PROCESS IS PRETTY FAST; THE SLOWDOWN CAME FROM THE MASK GENERATOR"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# IMPLEMENT AN ACCELERATED MODULE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt = parser.opt\n",
    "opt.shift_sz = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Module path follows the package layout used by the first cell's imports\n",
    "# (models/soft_shift_net/..., models/accelerated_shift_net/...).\n",
    "from models.accelerated_shift_net.accelerated_InnerShiftTriple import AcceleratedInnerShiftTriple\n",
    "acce_inner_shift_triple = AcceleratedInnerShiftTriple(opt.threshold, opt.fixed_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "acce_inner_shift_triple.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "acce_inner_shift_triple.set_mask(mask_global=mask_global, threshold=opt.threshold, layer_to_last=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input copied to the GPU outside the timed statement; synchronize so the asynchronous\n",
    "# CUDA work is included in every measured loop.\n",
    "x_gpu = x_tr.cuda()\n",
    "%timeit acce_inner_shift_triple(x_gpu); torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "acce_inner_shift_triple.__dict__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Measure both layers here (same input; both had set_mask called with the same mask_global)\n",
    "# instead of dividing hard-coded timings that no cell in this notebook produces.\n",
    "x_gpu = x_tr.cuda()\n",
    "soft_timing = %timeit -o -q inner_shift_triple(x_gpu); torch.cuda.synchronize()\n",
    "acce_timing = %timeit -o -q acce_inner_shift_triple(x_gpu); torch.cuda.synchronize()\n",
    "print('THE SPEED UP IS {} FOLD'.format(round(soft_timing.average / acce_timing.average, 2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
